Skip to content

Support Metal as device#108

Open
haakon-e wants to merge 3 commits intohe/fix/add-FT-to-gpu-testsfrom
he/metal
Open

Support Metal as device#108
haakon-e wants to merge 3 commits intohe/fix/add-FT-to-gpu-testsfrom
he/metal

Conversation

@haakon-e
Copy link
Copy Markdown
Member

@haakon-e haakon-e commented Mar 13, 2025

Purpose

This is a speculative PR for adding Metal as a supported device type.
Merging this PR is not useful for simulations that depend on ClimaCore, since that package currently only supports CUDA. But, other packages like CloudMicrophysics may enjoy local GPU-enabled development without the need for an external server-based GPU. The broader Julia ecosystem may also enjoy this contribution, as this package is useful for non-CliMA projects too.

Content

  • Added ClimaCommsMetalExt extension that largely mirrors the ClimaCommsCUDAExt extension.
  • updated docs, including README.md, apis.md, faqs.md, index.md, internals.md, with description of the proposed Metal support.
  • Added MetalDevice struct, and extended methods like device_type and device to make use of it.
  • Added metal_ext_is_loaded and metal_is_required to look for Metal support, and extended the @import_required_backends macro to load Metal when applicable.
  • Updated tests, including:
    • arrays in hygiene.jl are now all Float32, which is required for Metal
    • runtests.jl is successful on Metal, but only Float32 is tested. Any Float64 tests are skipped.
  • Added test_cuda.jl and test_metal.jl for convenient execution of tests on the respective devices. Can remove these files if not desired.
    • Note: test_metal.jl is useful for manual testing since, as far as I know, we don't have easy access to a remote test server with a Metal-compatible GPU. Given this, it may be useful to label Metal support as "experimental" or "untested" or similar.

  • I have read and checked the items on the review checklist.

@haakon-e haakon-e requested review from Sbozzolo, charleskawczynski and dennisYatunin and removed request for Sbozzolo March 13, 2025 02:10
@coderabbitai
Copy link
Copy Markdown

coderabbitai bot commented May 30, 2025

Walkthrough

This update introduces support for Apple Metal GPUs in the ClimaComms package. The MetalDevice type is added, along with logic to detect and use Metal as a backend. The package configuration, documentation, and tests are updated to recognize Metal as a valid device option. A new extension module implements Metal-specific methods for device management, array adaptation, and core computational operations. Conditional logic ensures that Metal is only used if the backend is available. Test scripts and documentation are expanded to cover Metal usage and requirements, mirroring existing support for CUDA devices.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (2)
test/hygiene.jl (1)

14-14: Specify Float32 in rand calls
Using rand(Float32, …) ensures compatibility with Metal arrays. Consider refactoring repeated patterns into a helper function to reduce duplication.

Also applies to: 19-19, 24-24, 29-29, 37-37, 42-42, 48-48, 53-53

docs/src/faqs.md (1)

78-78: Fix markdown formatting issue.

Remove spaces inside the code span elements.

-so you might install packages like ` MPI.jl` and `CUDA.jl` (or `Metal.jl`).
+so you might install packages like `MPI.jl` and `CUDA.jl` (or `Metal.jl`).
🧰 Tools
🪛 markdownlint-cli2 (0.17.2)

78-78: Spaces inside code span elements
null

(MD038, no-space-in-code)

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a2bf168 and 192f8f6.

📒 Files selected for processing (14)
  • Project.toml (1 hunks)
  • README.md (1 hunks)
  • docs/src/apis.md (2 hunks)
  • docs/src/faqs.md (3 hunks)
  • docs/src/index.md (1 hunks)
  • docs/src/internals.md (2 hunks)
  • ext/ClimaCommsMetalExt.jl (1 hunks)
  • src/devices.jl (4 hunks)
  • src/loading.jl (2 hunks)
  • test/Project.toml (1 hunks)
  • test/hygiene.jl (1 hunks)
  • test/runtests.jl (8 hunks)
  • test/test_cuda.jl (1 hunks)
  • test/test_metal.jl (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
`*`: # CodeRabbit Style Guide (CliMA Inspired)

Leverage CodeRabbit for code reviews aligning with CliMA's practices.

I. Key Areas for CodeRabbit:

  • Style: Naming (Titl...

*: # CodeRabbit Style Guide (CliMA Inspired)

Leverage CodeRabbit for code reviews aligning with CliMA's practices.

I. Key Areas for CodeRabbit:

  • Style: Naming (TitleCase, lowercase_with_underscores), line length (<78), indentation (4 spaces), import order.
  • DRY: Flag duplicated code; encourage modularization.
  • Docstrings: Flag missing docstrings for modules, structs, functions.
  • Tests: Detect missing unit tests (if configured).
  • Complexity: Report on cyclomatic complexity.

II. Conventions (CodeRabbit Can Help):

  • Naming: Follow CliMA/CMIP conventions. Avoid l, O, I as single-char vars.
  • Unicode: Human review needed. Acceptable chars.

III. Documentation (CodeRabbit Flags Deficiencies):

  • Docstrings: Describe purpose, inputs, outputs, examples.

IV. Testing (CodeRabbit: Missing Tests):

  • Unit tests required for new/modified code.

V. CodeRabbit Config:

  • .coderabbit.yaml: Customize language, line length, tests.

VI. Workflow:

  • Review CodeRabbit's suggestions; investigate, provide context, address issues, improve config.

VII. Abandoned Code: Delete or clearly mark as deprecated.

VIII. Reviewer Responsibilities: Ensure adherence to best practices.

  • README.md
  • Project.toml
🪛 markdownlint-cli2 (0.17.2)
docs/src/faqs.md

78-78: Spaces inside code span elements
null

(MD038, no-space-in-code)

⏰ Context from checks skipped due to timeout of 90000ms (3)
  • GitHub Check: test-os (macos-latest)
  • GitHub Check: docbuild
  • GitHub Check: test-os (ubuntu-latest)
🔇 Additional comments (32)
test/Project.toml (1)

7-7: Add Metal dependency
Metal is now listed under [deps], aligning with the new backend support.

README.md (1)

18-18: Document MetalDevice support
The MetalDevice bullet is correctly added alongside existing devices.

docs/src/index.md (1)

23-23: Include MetalDevice in documentation
MetalDevice entry matches the list of supported devices and macros.

test/test_cuda.jl (1)

1-13: New CUDA test runner
Script properly configures the environment and includes the main test suite.

Project.toml (2)

19-19: LGTM - Extension and compatibility setup.

The extension registration and Metal v1 compatibility constraint follow Julia package conventions correctly.

Also applies to: 27-27


14-14:

Details

✅ Verification successful

Verify Metal.jl UUID.

Ensure the UUID matches the official Metal.jl package registration.


🏁 Script executed:

#!/bin/bash
# Check if the Metal.jl UUID is correct
curl -s "https://raw.githubusercontent.com/JuliaRegistries/General/master/M/Metal/Package.toml" | grep "uuid"

Length of output: 153


Metal.jl UUID Verified
The UUID “dde4c033-4e86-420c-a63e-0dd931031962” in Project.toml matches the official Metal.jl registry entry. No further action required.

test/test_metal.jl (2)

7-13: LGTM - Follows established test pattern.

The test structure mirrors test_cuda.jl appropriately, maintaining consistency across device-specific test runners.


4-5:

Details

✅ Verification successful

Environment variable naming consistency.

Verify these environment variable names match the expected conventions used throughout the codebase.


🏁 Script executed:

#!/bin/bash
# Check for consistent environment variable usage in the codebase
rg "CLIMACOMMS_DEVICE|CLIMACOMMS_TEST_DEVICE" --type julia -A 2 -B 2

Length of output: 2497


Environment variable naming consistency verified.
The CLIMACOMMS_DEVICE and CLIMACOMMS_TEST_DEVICE settings in test/test_metal.jl match the patterns in src/devices.jl, test/test_cuda.jl, and test/runtests.jl. No further changes needed.

docs/src/apis.md (1)

16-16:

Details

✅ Verification successful

Verify documented functions exist.

Ensure metal_is_required and MetalDevice are actually implemented in the codebase.

Also applies to: 28-28


🏁 Script executed:

#!/bin/bash
# Check that documented Metal functions and types exist
rg -A 3 "metal_is_required|MetalDevice" --type julia

Length of output: 7714


Documentation additions verified

Both metal_is_required and MetalDevice are implemented in the codebase:

  • metal_is_required() in src/loading.jl
  • struct MetalDevice in src/devices.jl

Approving these documentation updates.

docs/src/internals.md (2)

13-13: LGTM - Device list update.

MetalDevice appropriately added to the device types list.


121-125: LGTM - Comprehensive Metal backend documentation.

The Metal backend description follows the established pattern and provides clear information about the Metal.jl dependency requirement.

docs/src/faqs.md (2)

16-25: LGTM: Clear Metal device instructions added.

Documentation properly mirrors the CUDA device setup instructions.


91-91: LGTM: Consistent Metal device summary documentation.

Properly documents Metal device behavior in summary output.

src/loading.jl (3)

34-35: LGTM: Consistent Metal extension loading check.

Mirrors the CUDA pattern perfectly.


40-51: LGTM: Well-documented Metal requirement function.

Documentation and implementation are consistent with cuda_is_required().


58-58: LGTM: Macro properly extended for Metal support.

Conditional Metal import follows the established pattern.

Also applies to: 70-73

src/devices.jl (4)

39-44: LGTM: Well-documented MetalDevice struct.

Proper subtyping and clear documentation for M-series chips.


66-67: LGTM: Consistent Metal device type detection.

Follows the established pattern for device environment variable parsing.


83-83: LGTM: Documentation updated for Metal device.

Properly includes Metal in the allowed device values.


93-96: LGTM: Proper Metal backend validation.

Error handling mirrors CUDA implementation with appropriate error message.

test/runtests.jl (6)

23-24: LGTM: Consistent device type testing.

Properly tests MetalDevice instance detection.


40-43: LGTM: Appropriate Float64 exclusions for Metal.

Metal devices don't support Float64, so these skips are necessary and well-logged.

Also applies to: 111-114, 174-177, 192-195


237-240: LGTM: Consistent Float64 broadcast exclusion.

Properly excludes Float64 broadcast test for Metal devices.


244-244: LGTM: Float32 used for Metal compatibility.

Using Float32 ensures the test works across all device types.


249-250: LGTM: Consistent scalar access error testing.

Properly includes Metal alongside CUDA for expected scalar access errors.


302-309: LGTM: Comprehensive Adapt.jl testing for Metal.

Tests mirror CUDA Adapt.jl tests and ensure proper Metal array adaptation.

Also applies to: 326-339

ext/ClimaCommsMetalExt.jl (6)

1-8: LGTM! Clean module structure.

Standard Julia extension setup with appropriate imports.


14-17: LGTM! Appropriate no-op for Metal's automatic device management.

The implementation correctly reflects Metal's built-in device assignment behavior.


36-38: LGTM! Correct device availability check.

Simple and effective implementation.


45-60: LGTM! Proper Adapt.jl integration.

Both adaptation methods follow the correct patterns for device and context adaptation.


67-67: LGTM! Correct array type specification.


74-75: LGTM! Proper Metal scalar operation handling.

Comment thread ext/ClimaCommsMetalExt.jl
Comment thread ext/ClimaCommsMetalExt.jl
@haakon-e haakon-e changed the base branch from main to dy/gpu_threaded May 30, 2025 22:36
@haakon-e
Copy link
Copy Markdown
Member Author

haakon-e commented May 30, 2025

This change is part of the following stack:

Change managed by git-spice.

@haakon-e haakon-e removed the request for review from charleskawczynski May 30, 2025 22:38
@haakon-e
Copy link
Copy Markdown
Member Author

haakon-e commented May 30, 2025

All the independent threaded tests run locally

@testset "independent threaded" begin
a = AT(rand(100))
b = AT(rand(100))
is_single_cpu_thread =
device isa ClimaComms.CPUSingleThreaded &&
context isa ClimaComms.SingletonCommsContext
kernel1!(a, b) = ClimaComms.@threaded for i in axes(a, 1)
a[i] = b[i]
end
kernel1!(a, b)
@test a == b
is_single_cpu_thread && @test (@allocated kernel1!(a, b)) == 0
kernel2!(a, b) = ClimaComms.@threaded coarsen=:static for i in axes(a, 1)
a[i] = 2 * b[i]
end
kernel2!(a, b)
@test a == 2 * b
is_single_cpu_thread && @test (@allocated kernel2!(a, b)) == 0
kernel3!(a, b) = ClimaComms.@threaded device coarsen=3 for i in axes(a, 1)
a[i] = 3 * b[i]
end
kernel3!(a, b)
@test a == 3 * b
is_single_cpu_thread && @test (@allocated kernel3!(a, b)) == 0
kernel4!(a, b) = ClimaComms.@threaded device coarsen=400 for i in axes(a, 1)
a[i] = 4 * b[i]
end
kernel4!(a, b)
@test a == 4 * b
is_single_cpu_thread && @test (@allocated kernel4!(a, b)) == 0
kernel5!(a, b) = ClimaComms.@threaded block_size=50 for i in axes(a, 1)
a[i] = 5 * b[i]
end
kernel5!(a, b)
@test a == 5 * b
is_single_cpu_thread && @test (@allocated kernel5!(a, b)) == 0
end

the interdependent threaded example also works

@testset "interdependent threaded" begin
set_deriv_at_point!(output, input, i) =
if i == 1
output[1] = input[2] - input[1]
elseif i == 100
output[100] = input[100] - input[99]
else
output[i] = (input[i + 1] - 2 * input[i] + input[i - 1]) / 2
end
function threaded_deriv_with_respect_to_i(device, input, i)
output =
ClimaComms.static_shared_memory_array(device, eltype(input), 100)
ClimaComms.@sync_interdependent i set_deriv_at_point!(output, input, i)
return output
end
function unthreaded_deriv_with_respect_to_i(device, input)
output = MArray{Tuple{100}, eltype(input)}(undef)
for i in axes(input, 1)
set_deriv_at_point!(output, input, i)
end
return output
end
threaded_deriv3_with_respect_to_i!(device, a, b) =
ClimaComms.@threaded device for i in @interdependent(axes(a, 1))
∂b_∂i = threaded_deriv_with_respect_to_i(device, b, i)
∂²b_∂i² = threaded_deriv_with_respect_to_i(device, ∂b_∂i, i)
∂³b_∂i³ = threaded_deriv_with_respect_to_i(device, ∂²b_∂i², i)
ClimaComms.@sync_interdependent i a[i] = ∂³b_∂i³[i]
end
unthreaded_deriv3_with_respect_to_i!(device, a, b) =
ClimaComms.allowscalar(device) do
∂b_∂i = unthreaded_deriv_with_respect_to_i(device, b)
∂²b_∂i² = unthreaded_deriv_with_respect_to_i(device, ∂b_∂i)
∂³b_∂i³ = unthreaded_deriv_with_respect_to_i(device, ∂²b_∂i²)
for i in axes(a, 1)
a[i] = ∂³b_∂i³[i]
end
end
a_threaded = AT(rand(100))
a_unthreaded = AT(rand(100))
b = AT(rand(100))
is_single_cpu_thread =
device isa ClimaComms.CPUSingleThreaded &&
context isa ClimaComms.SingletonCommsContext
threaded_deriv3_with_respect_to_i!(device, a_threaded, b)
unthreaded_deriv3_with_respect_to_i!(device, a_unthreaded, b)
@test a_threaded == a_unthreaded
# TODO: Figure out source of allocations for interdependent iterators.
threaded_allocations =
@allocated threaded_deriv3_with_respect_to_i!(device, a_threaded, b)
@info "Allocated $threaded_allocations bytes"
is_single_cpu_thread && @test_broken threaded_allocations == 0
end

but the independent and interdependent threaded test does not

@testset "independent and interdependent threaded" begin
set_deriv_at_point!(output, input, i, j...) =
if i == 1
output[1] = input[2, j...] - input[1, j...]
elseif i == 100
output[100] = input[100, j...] - input[99, j...]
else
output[i] =
(input[i + 1, j...] - 2 * input[i, j...] + input[i - 1, j...]) /
2
end
function threaded_deriv_with_respect_to_i(device, input, i, j...)
output =
ClimaComms.static_shared_memory_array(device, eltype(input), 100)
ClimaComms.@sync_interdependent i begin
set_deriv_at_point!(output, input, i, j...)
end
return output
end
function unthreaded_deriv_with_respect_to_i(device, input, j...)
output = MArray{Tuple{100}, eltype(input)}(undef)
for i in axes(input, 1)
set_deriv_at_point!(output, input, i, j...)
end
return output
end
threaded_deriv3_with_respect_to_i!(device, a, b) =
ClimaComms.@threaded device begin
for i in @interdependent(axes(a, 1)), j in axes(a, 2)
∂b_∂i = threaded_deriv_with_respect_to_i(device, b, i, j)
∂²b_∂i² = threaded_deriv_with_respect_to_i(device, ∂b_∂i, i)
∂³b_∂i³ = threaded_deriv_with_respect_to_i(device, ∂²b_∂i², i)
ClimaComms.@sync_interdependent a[i, j] = ∂³b_∂i³[i]
end
end
unthreaded_deriv3_with_respect_to_i!(device, a, b) =
ClimaComms.allowscalar(device) do
for j in axes(a, 2)
∂b_∂i = unthreaded_deriv_with_respect_to_i(device, b, j)
∂²b_∂i² = unthreaded_deriv_with_respect_to_i(device, ∂b_∂i)
∂³b_∂i³ = unthreaded_deriv_with_respect_to_i(device, ∂²b_∂i²)
for i in axes(a, 1)
a[i, j] = ∂³b_∂i³[i]
end
end
end
a_threaded = AT(rand(100, 100))
a_unthreaded = AT(rand(100, 100))
b = AT(rand(100, 100))
is_single_cpu_thread =
device isa ClimaComms.CPUSingleThreaded &&
context isa ClimaComms.SingletonCommsContext
threaded_deriv3_with_respect_to_i!(device, a_threaded, b)
unthreaded_deriv3_with_respect_to_i!(device, a_unthreaded, b)
@test a_threaded == a_unthreaded
# TODO: Figure out source of allocations for interdependent iterators.
threaded_allocations =
@allocated threaded_deriv3_with_respect_to_i!(device, a_threaded, b)
@info "Allocated $threaded_allocations bytes"
is_single_cpu_thread && @test_broken threaded_allocations == 0
end

This last test does run, but the the a_threaded matrix is never modified. The `an_unthreaded matrix is indeed modified. This is verifiable by considering the following local modification to the test:

FT = Float32
a_threaded = AT(rand(FT, 100, 100))
a_unthreaded = AT(rand(FT, 100, 100))
b = AT(rand(FT, 100, 100))

a_threaded_backup = copy(a_threaded)
a_unthreaded_backup = copy(a_unthreaded)

is_single_cpu_thread =
    device isa ClimaComms.CPUSingleThreaded &&
    context isa ClimaComms.SingletonCommsContext

threaded_deriv3_with_respect_to_i!(device, a_threaded, b)
unthreaded_deriv3_with_respect_to_i!(device, a_unthreaded, b)
@test_broken a_threaded == a_unthreaded
@test a_threaded == a_threaded_backup
@test a_unthreaded != a_unthreaded_backup

@haakon-e haakon-e changed the base branch from dy/gpu_threaded to he/fix/add-FT-to-gpu-tests May 30, 2025 22:55
@charleskawczynski
Copy link
Copy Markdown
Member

Sorry for not reviewing this earlier. I'm just not sure that using (and spreading the use of) @threaded is a wise decision. Please see this comment. Perhaps most importantly, it is not clear to me that this pattern will work with respect to highly fused kernels with lots of broadcasted objects, which may be needed to fix CliMA/ClimaAtmos.jl#3860 unless the idea is that we are planning to no longer use ClimaCore's operators, but I imagine that that decision should be a separate and thorough discussion.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants